from abc import ABC, abstractmethod

import taichi as ti

from .utils import Vec3, Ray3, IdGenerator
from .color import Colorer, PureColor
from .config import MAX_MATERIAL, MAX_MATERIAL_ATTRIBUTE

# Public API of this module: only the Material interface is exported;
# MaterialUtil is the internal storage record and stays private.
__all__ = ["Material"]


@ti.dataclass
class MaterialUtil:
    """Underlying storage record for a single material slot.

    ``Material.materials`` holds ``MAX_MATERIAL`` of these, one per
    ``Material`` instance. Field order defines the struct layout.
    """

    # id of the Material subclass occupying this slot (written from the
    # subclass's ``typeId`` in ``Material.__init__``)
    typeId: int
    # index of the Colorer assigned to this material (``Colorer.index``,
    # written by ``Material.setColor``)
    color: int
    # per-material parameters, MAX_MATERIAL_ATTRIBUTE floats; the meaning of
    # each component is not visible here -- presumably defined by each
    # Material subclass (TODO confirm)
    attribute: ti.types.vector(MAX_MATERIAL_ATTRIBUTE, float)


class Material(ABC):
    """Operation interface for materials.

    Every instance claims one slot (``self.index``) in the shared taichi
    struct field ``Material.materials``. Concrete subclasses are registered
    lazily: the first time a subclass is instantiated it receives a
    ``typeId`` from ``getId`` and is recorded in ``used``, which
    ``handleAllScatter`` uses to dispatch to the matching ``scatter``.
    """

    # id of the concrete subclass; None on the base class, assigned per
    # subclass on its first instantiation
    typeId: int = None
    # declared only; presumably assigned elsewhere to a default material
    # instance (TODO confirm)
    default: "Material"
    # copied into the ``end`` flag returned by ``handleAllScatter`` when a
    # ray hits this material, i.e. whether tracing ends at this material
    asEnd: bool = False
    # storage for all material instances, MAX_MATERIAL slots
    materials: ti.StructField = MaterialUtil.field(shape=MAX_MATERIAL)
    # single-element counter: number of slots in ``materials`` already used
    materialsNum: ti.ScalarField = ti.field(int, 1)
    # registry typeId -> concrete subclass, shared by all subclasses
    used: "dict[int, type]" = {}
    # source of subclass typeIds (consumed with ``next``)
    getId = IdGenerator()

    def __init__(self) -> None:
        """Allocate a storage slot for this material and initialise it.

        Registers ``type(self)`` on first use, stores its ``typeId`` in the
        newly claimed slot and assigns ``Colorer.default`` as the color.

        Raises:
            MemoryError: if all ``MAX_MATERIAL`` slots are already taken.
        """
        if Material.materialsNum[0] >= MAX_MATERIAL:
            raise MemoryError("number of Materials out of design.")
        # register the concrete subclass once; later instances reuse its id
        if type(self) not in self.used.values():
            type(self).typeId = next(self.getId)
            self.used[self.typeId] = type(self)
        self.index: int = Material.materialsNum[0]
        Material.materialsNum[0] += 1
        Material.materials[self.index].typeId = self.typeId
        self.setColor(Colorer.default)

    def setColor(self, r: float = None, g: float = None, b: float = None):
        """Set the material's color (albedo).

        Args:
            r: either a ``Colorer`` to use directly, or the red component
                passed to ``PureColor`` together with ``g`` and ``b``.
            g: green component; ignored when ``r`` is a ``Colorer``.
            b: blue component; ignored when ``r`` is a ``Colorer``.

        Returns:
            ``self``, so calls can be chained.
        """
        if not isinstance(r, Colorer):
            r = PureColor(r, g, b)
        # the slot stores only the colorer's index, not the color value
        Material.materials[self.index].color = r.index
        return self

    @abstractmethod
    def scatter(self: int, ray: Ray3, hitPoint: Vec3, norm: Vec3):
        """Scatter ``ray`` off this material at ``hitPoint``.

        Called by ``handleAllScatter`` as
        ``cls.scatter(index, ray, hitPoint, norm)`` from inside a
        ``ti.func``. ``self`` is therefore the material's slot index in
        ``Material.materials``, not an instance. Implementations presumably
        need to be ``@ti.func`` (TODO confirm).

        Returns:
            A 2-tuple ``(color, nextRay)``. The "end tracing" flag is NOT
            returned here; ``handleAllScatter`` takes it from the class
            attribute ``asEnd``.
        """

    @ti.func
    def getColor(self: int, hitPoint: Vec3) -> Vec3:
        """Evaluate the color of material slot ``self`` at ``hitPoint``.

        ``self`` is a slot index into ``Material.materials``. The stored
        colorer index is resolved via ``Colorer.handleAllColorer``.
        """
        return Colorer.handleAllColorer(Material.materials[self].color, hitPoint)

    @ti.func
    def handleAllScatter(self: int, rayObj: Ray3, hitPoint: Vec3, norm: Vec3):
        """Dispatch scattering for material slot ``self`` to its subclass.

        ``self`` is a slot index into ``Material.materials``. The loop over
        ``Material.used`` is unrolled with ``ti.static`` when the enclosing
        kernel is compiled, so only subclasses registered by then are
        dispatched to.

        Returns:
            ``(end, color, ray)``:
            - ``end`` is the matching subclass's ``asEnd``.
            - ``color`` and ``ray`` come from that subclass's ``scatter``.
            If no registered type matches, returns
            ``(False, Vec3(0), Ray3())``.
        """
        typeId = Material.materials[self].typeId
        end, color, ray = False, Vec3(0), Ray3()
        for material in ti.static(tuple(Material.used.values())):
            # runtime branch inside a compile-time-unrolled loop
            if typeId == material.typeId:
                end = material.asEnd
                color, ray = material.scatter(self, rayObj, hitPoint, norm)
        return end, color, ray
